# 缓存
import configparser
import os

# 工作目录
import time

import pymysql

# Working directory of the process at import time; the cache root lives beneath it.
CWD = os.getcwd()

# Cache root directory. It keeps a trailing slash, and it is created at import time
# so that every helper below can assume it exists.
CACHE_CWD = f'{CWD}/cache/'
cache_root_missing = not os.path.exists(CACHE_CWD)
if cache_root_missing:
    os.makedirs(CACHE_CWD)
del cache_root_missing


def map_key(m, k, d=''):
    """
    Safely read ``m[k]`` with a fallback.

    :param m: mapping to read from; anything that is not a ``dict`` yields the fallback
    :param k: key to look up
    :param d: fallback returned when ``m`` is not a dict, ``k`` is absent, or ``m[k]`` is falsy
    :return: ``m[k]`` when present and truthy, otherwise ``d``
    """
    if not isinstance(m, dict) or k not in m:
        return d
    # Falsy stored values (None, '', 0, ...) also fall back to the default.
    return m[k] or d


# 生成缓存文件目录
def check_model_cache(name):
    """
    Ensure the cache directory of model ``name`` exists and return its path.

    :param name: cache (project) name
    :return: the directory path ``f'{CACHE_CWD}/{name}/'`` (with a trailing slash)
    """
    model_dir = f'{CACHE_CWD}/{name}/'
    if os.path.exists(model_dir):
        return model_dir
    os.makedirs(model_dir)
    return model_dir


def get_model_cache(name, file=None):
    """
    Build a path inside the cache directory of model ``name``.

    The directory is not created here; see ``check_model_cache`` for that.

    :param name: cache (project) name
    :param file: optional file/sub-directory name; a falsy value yields the directory itself
    :return: ``f'{CACHE_CWD}/{name}/{file}'``
    """
    suffix = file or ''
    return f'{CACHE_CWD}/{name}/{suffix}'


# 数据缓存
class DatabaseCache:
    """
    On-disk schema cache for one MySQL database.

    Files live under ``CACHE_CWD/<name>/``:

    * ``_config.ini``         -- connection settings (section ``[database]``), written by ``cache_conn_ini``
    * ``models/<table>.ini``  -- one file per table/view, written by ``cache_struct``: an ``[object]``
      section plus one ``[c_<column>]`` section per column
    * ``sqls/tb_<table>.sql`` -- the ``SHOW CREATE TABLE`` output for that table/view

    Every ``ConfigParser`` is created with interpolation disabled, so values containing ``%``
    (passwords, comments, defaults) are stored verbatim instead of raising ``ValueError``.
    """

    def __init__(self, dbConfig):
        """
        :param dbConfig: either a cache name (``str``), whose ``_config.ini`` is loaded when it exists,
            or a dict of ``pymysql.connect`` arguments (``host``, ``user``, ``password``, ``db``,
            ``charset``, ...). The dict may carry an extra ``name`` key that is used as the cache name.
        """
        self.dbConfig = dbConfig
        self.basedir = None
        self.db_conn = None
        self._struct_basedir = None
        self._sql_basedir = None
        self.name = None
        # In-memory mirror of the cache files:
        #   'config' -> connection settings
        #   'object' -> {table_name: per-table model with sections 'object' and 'c_<column>'}
        self.dataCache = {}

        # When dbConfig is a cache name, try to load its saved configuration.
        self.read_name(dbConfig)

    @staticmethod
    def _new_parser():
        """Return a ConfigParser that stores values verbatim (no ``%`` interpolation)."""
        return configparser.ConfigParser(interpolation=None)

    @staticmethod
    def _quote_ident(name):
        """Backtick-quote a MySQL identifier, doubling any embedded backtick."""
        return '`' + str(name).replace('`', '``') + '`'

    def _create_basedir(self):
        """
        Compute the ``models`` and ``sqls`` directory paths for this cache (the directories are not created).

        :return: self
        """
        self._struct_basedir = get_model_cache(self.name, 'models')
        self._sql_basedir = get_model_cache(self.name, 'sqls')
        return self

    def read_name(self, name):
        """
        Load connection settings from the ``_config.ini`` of cache ``name``.

        This does nothing unless ``name`` is a non-empty string whose config file exists and contains
        a ``[database]`` section. On success, ``dbConfig``, ``name`` and ``dataCache['config']`` are set.

        :param name: cache name
        """
        if not (name and isinstance(name, str)):
            return
        filename = get_model_cache(name, '_config.ini')
        if not os.path.exists(filename):
            return

        config = self._new_parser()
        config.read(filename, encoding='utf-8')
        if 'database' in config:
            self.dbConfig = config['database']
            self.name = name
            self.dataCache['config'] = self.dbConfig

    def read_models(self):
        """
        Load every cached table model from ``models/`` into ``dataCache['object']``.

        Files without an ``[object]`` section are skipped. They are not cache entries.
        Files are read in sorted order so that diff/patch output is deterministic.
        """
        vpath = get_model_cache(self.name, 'models')
        if not os.path.exists(vpath):
            return

        c_objects = {}
        for file in sorted(os.listdir(vpath)):
            config = self._new_parser()
            config.read(os.path.join(vpath, file), encoding='utf-8')
            if not config.has_section('object'):
                continue
            c_objects[config['object']['table_name']] = config
        self.dataCache['object'] = c_objects

    def cache_conn_ini(self):
        """
        Persist the connection settings (a dict) to ``<cache>/_config.ini``.

        The cache name is ``dbConfig['name']`` if given. Otherwise it is the database name, taken from
        ``db`` or pymysql's ``database`` spelling. A ``KeyError`` is raised if none of these is present.
        """
        dbConfig = self.dbConfig
        if 'name' in dbConfig:
            name = dbConfig['name']
        elif 'db' in dbConfig:
            name = dbConfig['db']
        else:
            name = dbConfig['database']

        self.name = name
        self.basedir = check_model_cache(name)
        config = self._new_parser()
        config['database'] = dbConfig
        self.dataCache['config'] = dbConfig

        filename = self.basedir + '_config.ini'
        with open(filename, 'w', encoding='utf-8') as fh:
            config.write(fh)

    def _connect_kwargs(self):
        """
        Build the keyword arguments for ``pymysql.connect`` from ``dbConfig``.

        ``name`` is this module's own cache label, not a pymysql option, so it is dropped.
        Settings reloaded from ``_config.ini`` are all strings, so a numeric ``port`` is converted back to ``int``.
        """
        kwargs = dict(self.dbConfig)
        kwargs.pop('name', None)
        port = kwargs.get('port')
        if isinstance(port, str) and port.isdigit():
            kwargs['port'] = int(port)
        return kwargs

    def get_conn(self):
        """
        Return the (lazily created, then reused) pymysql connection using ``DictCursor`` rows.
        """
        if not self.db_conn:
            self.db_conn = pymysql.connect(**self._connect_kwargs(), cursorclass=pymysql.cursors.DictCursor)
        return self.db_conn

    def _cache_columns(self, cursor, table_name, config):
        """
        Add one ``[c_<column>]`` section per column of ``table_name`` to ``config``.

        :return: the column names in the order ``SHOW COLUMNS`` returns them
        """
        # @todo could read information_schema.COLUMNS (TABLE_SCHEMA = DATABASE()) instead
        cursor.execute(f'show columns from {self._quote_ident(table_name)}')
        columns = []
        for col in cursor.fetchall():
            field = col['Field']
            columns.append(field)
            config[f'c_{field}'] = {
                'type': map_key(col, 'Type'),
                'v_null': map_key(col, 'Null'),
                'key': map_key(col, 'Key'),
                'default': map_key(col, 'Default'),
                'extra': map_key(col, 'Extra'),
            }
        return columns

    def _cache_create_sql(self, cursor, table_name):
        """Write the ``SHOW CREATE TABLE`` DDL of ``table_name`` to ``sqls/tb_<table>.sql`` (skipped when empty)."""
        cursor.execute(f'show create table {self._quote_ident(table_name)}')
        crt_result = cursor.fetchone()
        if not crt_result:
            return
        # Tables return a 'Create Table' column and views a 'Create View' column.
        ddl = crt_result.get('Create Table') or crt_result.get('Create View')
        if ddl:
            with open(f'{self._sql_basedir}/tb_{table_name}.sql', 'w', encoding='utf-8') as fh:
                fh.write(ddl)

    def cache_struct(self):
        """
        Query the live database and (re)write the structure cache.

        For every table/view in the connected schema, this writes ``models/<table>.ini`` and
        ``sqls/tb_<table>.sql``. It also stores the per-table ``ConfigParser`` (all sections) in
        ``dataCache['object']``, which is the same shape that ``read_models`` produces.
        """
        conn = self.get_conn()
        if not conn:
            return

        self._create_basedir()
        for vdir in (self._struct_basedir, self._sql_basedir):
            if not os.path.exists(vdir):
                os.makedirs(vdir)

        c_objects = {}
        with conn.cursor() as cursor:
            cursor.execute("select * from information_schema.TABLES t where TABLE_SCHEMA = DATABASE()")
            for res in cursor.fetchall():
                if not res:
                    continue

                table_name = res['TABLE_NAME']
                config = self._new_parser()
                config['object'] = {
                    'table_name': table_name,
                    'vtype': map_key(res, 'TABLE_TYPE'),
                    'engine': map_key(res, 'ENGINE'),
                    'collation': map_key(res, 'TABLE_COLLATION'),
                    'comment': map_key(res, 'TABLE_COMMENT'),
                    'create_time': map_key(res, 'CREATE_TIME'),
                    'update_time': map_key(res, 'UPDATE_TIME'),
                }
                # Always written (possibly empty), so readers can rely on the option existing.
                config['object']['columns'] = ','.join(self._cache_columns(cursor, table_name, config))
                self._cache_create_sql(cursor, table_name)

                with open(f'{self._struct_basedir}/{table_name}.ini', 'w', encoding='utf-8') as fh:
                    config.write(fh)
                c_objects[table_name] = config

        self.dataCache['object'] = c_objects

    def is_success(self):
        """
        Report whether usable connection settings were resolved.

        :return: ``(True, '')`` on success, otherwise ``(False, message)``. A ``dbConfig`` that is still
            a string means the named cache had no readable ``_config.ini``.
        """
        if not self.dbConfig:
            return False, '参数无效'
        elif isinstance(self.dbConfig, str):
            return False, f'{self.dbConfig} 无此连接缓存项目'

        return True, ''

    def get_sql(self, table_name):
        """
        Read the cached ``CREATE`` statement of ``table_name``.

        :param table_name: table/view name
        :return: the SQL text, or ``None`` when it was not cached
        """
        self._create_basedir()
        sql_file = f'{self._sql_basedir}/tb_{table_name}.sql'
        if os.path.exists(sql_file):
            with open(sql_file, 'r', encoding='utf-8') as fh:
                return fh.read()
        return None

    @staticmethod
    def list():
        """
        List the cache names under ``CACHE_CWD``, i.e. sub-directories that contain a ``_config.ini``.

        :return: list of cache names
        """
        cache_list = []
        for file in os.listdir(CACHE_CWD):
            vpath = f'{CACHE_CWD}/{file}'
            if os.path.isdir(vpath) and os.path.exists(f'{vpath}/_config.ini'):
                cache_list.append(file)

        return cache_list


def _array_to_map(arr):
    """
    数组转map
    :param arr: list
    :return:
    """
    arr = arr if isinstance(arr, list) else []
    am = {}
    for a in arr:
        am[a] = True
    return am


def _tow_arr_to_map(a1, a2):
    """
    两个数组并行生成map
    :param a1:
    :param a2:
    :return:
    """
    a1 = a1 if isinstance(a1, list) else []
    a2 = a2 if isinstance(a2, list) else []
    c1_len = len(a1)
    c2_len = len(a2)
    max_len = c1_len if c1_len > c2_len else c2_len

    m1, m2 = {}, {}
    i = 0
    while i < max_len:
        if i < c1_len:
            m1[a1[i]] = True
        if i < c2_len:
            m2[a2[i]] = True

        i += 1

    return m1, m2


class DiffCache:
    """
    Compare two schema caches produced by ``DatabaseCache``.

    ``name1`` is the source schema and ``name2`` is the target. ``diff`` prints how the target differs
    from the source, and ``patch`` writes SQL intended to turn the source schema into the target one.
    """

    def __init__(self):
        # The [name1, name2] pair used by the most recent diff()/patch() call.
        self.names = []

    @staticmethod
    def _q(name):
        """Backtick-quote a MySQL identifier, doubling any embedded backtick."""
        return '`' + str(name).replace('`', '``') + '`'

    @staticmethod
    def _columns_of(model):
        """
        Return the column names recorded in a table model's ``[object] columns`` option.

        A missing or empty option yields ``[]`` rather than ``KeyError`` / ``['']``.
        """
        columns = model['object'].get('columns', '')
        return [c for c in columns.split(',') if c]

    @staticmethod
    def _column_changes(c1, c2):
        """
        Compare two column lists.

        :return: ``(added, dropped)``: columns only in ``c2`` and columns only in ``c1``,
            each in original order and de-duplicated
        """
        in_c1, in_c2 = set(c1), set(c2)
        added = [c for c in dict.fromkeys(c2) if c not in in_c1]
        dropped = [c for c in dict.fromkeys(c1) if c not in in_c2]
        return added, dropped

    @staticmethod
    def _default_literal(v_type, default):
        """
        Render a cached column default as an SQL literal.

        - String-like columns (char/text/binary/enum/set) are always quoted.
        - Date/time columns are quoted unless the default is ``CURRENT_TIMESTAMP`` / ``NOW()``.
        - Anything else (numbers, bit literals, ...) is emitted as-is.

        Embedded single quotes are doubled.
        """
        t = v_type.lower()
        quoted = "'" + default.replace("'", "''") + "'"
        if any(word in t for word in ('char', 'text', 'binary', 'enum', 'set')):
            return quoted
        if ('date' in t or 'time' in t) and not default.upper().startswith(('CURRENT_TIMESTAMP', 'NOW(')):
            return quoted
        return default

    @staticmethod
    def _column_definition(attrs):
        """
        Build a ``<type> [not null] [default ...] [extra]`` column definition from a cached ``[c_<column>]`` section.

        Note: MySQL 8 expression defaults (``DEFAULT (expr)``) and generated columns are not reconstructed.
        """
        v_type = attrs.get('type', '')
        parts = [v_type]
        if attrs.get('v_null', '') == 'NO':
            parts.append('not null')
        default = attrs.get('default', '')
        if default:
            parts.append('default ' + DiffCache._default_literal(v_type, default))
        # MySQL 8 marks expression defaults with DEFAULT_GENERATED in Extra. That marker is not valid DDL.
        extra = ' '.join(w for w in attrs.get('extra', '').split() if w.upper() != 'DEFAULT_GENERATED')
        if extra:
            parts.append(extra)
        return ' '.join(parts)

    def _load_objects(self, name1, name2):
        """
        Open both caches and load their table models.

        :return: ``(dc1_object, dc2_object, dc2)``, or ``None`` after printing a message when either cache is missing
        """
        self.names = [name1, name2]
        dc1, dc2 = DatabaseCache(name1), DatabaseCache(name2)

        # Existence check
        for name, dc in ((name1, dc1), (name2, dc2)):
            scs_mk, msg = dc.is_success()
            if not scs_mk:
                print(f" {name} 不存在，{msg}")
                return None

        dc1.read_models()
        dc2.read_models()
        return dc1.dataCache.get('object', {}), dc2.dataCache.get('object', {}), dc2

    def diff(self, name1, name2):
        """
        Print the differences between two caches.

        Output lines: ``D-`` for a table dropped in name2, ``A+`` for a table added in name2, and ``M~``
        for a table whose columns changed (followed by ``+[added] -[dropped]``).

        :param name1: source cache name
        :param name2: target cache name
        """
        loaded = self._load_objects(name1, name2)
        if loaded is None:
            return
        dc1_object, dc2_object, _ = loaded

        for tb, model1 in dc1_object.items():
            if tb not in dc2_object:
                print(f' D-   {tb}')
                continue

            added, dropped = self._column_changes(self._columns_of(model1), self._columns_of(dc2_object[tb]))
            parts = []
            if added:
                parts.append('+[' + ','.join(added) + ']')
            if dropped:
                parts.append('-[' + ','.join(dropped) + ']')
            if parts:
                col_desc = ' '.join(parts)
                print(f' M~   {tb}')
                print(f'        column: {col_desc}')

        for tb in dc2_object:
            if tb not in dc1_object:
                print(f' A+   {tb}')

    def patch(self, name1, name2):
        """
        Write an SQL patch that migrates the name1 schema towards name2.

        The patch contains table drops, column drops/adds, and the cached ``CREATE`` statement of new tables.
        It is written to ``CACHE_CWD/<name1>_vs_<name2>_patch.sql``. Column type/attribute changes are not detected.

        :param name1: source cache name
        :param name2: target cache name
        """
        loaded = self._load_objects(name1, name2)
        if loaded is None:
            return
        dc1_object, dc2_object, dc2 = loaded

        patch_sqls = []
        for tb, model1 in dc1_object.items():
            if tb not in dc2_object:
                patch_sqls.append(f'Drop table if exists {self._q(tb)}')
                continue

            model2 = dc2_object[tb]
            added, dropped = self._column_changes(self._columns_of(model1), self._columns_of(model2))
            for col in dropped:
                patch_sqls.append(f'alter table {self._q(tb)} drop column {self._q(col)}')
            for col in added:
                ckey = f'c_{col}'
                if ckey in model2:
                    col_def = self._column_definition(model2[ckey])
                    patch_sqls.append(f'alter table {self._q(tb)} add {self._q(col)} {col_def}')

        for tb in dc2_object:
            if tb not in dc1_object:
                tb_create = dc2.get_sql(tb)
                if tb_create:
                    patch_sqls.append(tb_create)

        if patch_sqls:
            patch_name = f'{CACHE_CWD}{name1}_vs_{name2}_patch.sql'
            with open(patch_name, 'w', encoding='utf-8') as fh:
                fh.write(";\r\n".join(patch_sqls) + ";")
            print(f' patch已生成: {patch_name}')
        else:
            print(' 未生成，patch 脚本！可能数据库结构无差别')
